from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np 
# from pdb import set_trace
import time
# from pdb import set_trace




def calculate_sequence_probability(model, tokenizer, input_sequence, output_sequence, encoded_sequence, device, remove_start_token = False):
    """
    Score an output sequence under an LLM, conditioned on an input sequence.

    The input and output token ids are concatenated, passed through the model
    once, and the log-probabilities of the output tokens are summed. The first
    token of the output sequence is never scored.

    Args:
    - model: Callable as ``model(ids, labels=ids)`` returning an object with a
      ``.logits`` tensor of shape (1, sequence_length, vocab_size). If
      ``model.config.is_encoder_decoder`` is truthy the model is treated as an
      encoder-decoder model, otherwise as a decoder-only (causal) model.
    - tokenizer: The tokenizer to be used. Only used when ``encoded_sequence`` is False.
    - input_sequence (str or tensor): The input text, or its token ids when ``encoded_sequence`` is True.
    - output_sequence (str or tensor): The output text to score, or its token ids when ``encoded_sequence`` is True.
    - encoded_sequence (bool): Whether the input and output sequences are already token-id tensors.
      Pre-encoded tensors of any shape are flattened to shape (1, length).
    - device: The device the token ids are moved to (e.g. "cpu" or "cuda"). It should be the device the model lives on.
    - remove_start_token (bool): If True, the first output token is left out of the concatenated
      sequence (use this when the output starts with a start token that must not be repeated
      after the input). If False, the full output is appended to the input.

    Returns:
    - float: The probability of the scored output tokens.
    - float: The log probability of the scored output tokens.

    Raises:
    - ValueError: If a decoder-only model is asked to score a token that has no preceding context
      (empty input together with ``remove_start_token=True``).
    """

    # Encode the sequences, or normalise pre-encoded ids to shape (1, length)
    if not encoded_sequence:
        input_ids = tokenizer.encode(input_sequence, return_tensors="pt")
        output_ids = tokenizer.encode(output_sequence, return_tensors="pt")
    else:
        input_ids = input_sequence.reshape(1, -1)
        output_ids = output_sequence.reshape(1, -1)

    input_ids = input_ids.to(device)
    output_ids = output_ids.to(device)

    # Tokens that get scored: everything in the output except its first token
    target_ids = output_ids[0, 1:]
    num_targets = target_ids.numel()
    if num_targets == 0:
        # Nothing to score: empty product / empty sum
        return 1.0, 0.0

    # Concatenate input and output sequences.
    # NOTE: the combined length is not checked against the model's maximum length.
    input_length = input_ids.size(1)
    if remove_start_token:
        sequence = torch.cat((input_ids, output_ids[:, 1:]), dim=1)
        first_target_position = input_length
    else:
        sequence = torch.cat((input_ids, output_ids), dim=1)
        first_target_position = input_length + 1

    # Which logit row scores the token at position p depends on the model family:
    # - decoder-only models: logits[p - 1] is the distribution over the token at position p;
    # - encoder-decoder models called with labels: the decoder inputs are the labels shifted
    #   right, so logits[p] is the distribution over labels[p].
    # NOTE(review): for encoder-decoder models the concatenated sequence is also fed to the
    # encoder, so the encoder sees the output tokens it is scoring - confirm this is intended.
    is_encoder_decoder = getattr(getattr(model, "config", None), "is_encoder_decoder", False)
    first_logit_index = first_target_position if is_encoder_decoder else first_target_position - 1
    if first_logit_index < 0:
        raise ValueError("Cannot score the first token of the sequence with a decoder-only model: no preceding context.")

    # Pass the sequence through the model.
    # `labels` is kept because encoder-decoder models need it to build their decoder inputs.
    with torch.no_grad():
        logits = model(sequence, labels=sequence).logits

    # Gather the log-probability of every scored token in one shot
    step_logits = logits[0, first_logit_index:first_logit_index + num_targets].float()
    step_log_probs = F.log_softmax(step_logits, dim=-1)
    picked = step_log_probs.gather(1, target_ids.to(step_log_probs.device).unsqueeze(1)).squeeze(1)

    # Accumulate in Python floats (double precision); exponentiate the sum once instead of
    # multiplying per-token probabilities
    sequence_log_prob = float(sum(picked.tolist()))
    sequence_probability = torch.tensor(sequence_log_prob, dtype=torch.float64).exp().item()

    return sequence_probability, sequence_log_prob


def sample_sentences(model, tokenizer, input_sequence, device, min_length = 10, max_length = 100, top_k=50, top_p=0.95, temperature = 0.8, 
                    num_return_sequences=10, batch_size = 10,print_output = True): 
    """
    Sample sentences from a LLM model given an input sequence.

    Sequences are generated in batches of at most ``batch_size`` via ``model.generate``
    with sampling enabled. If a generated sequence starts with the input token ids
    (as decoder-only models return it), that prefix is removed.

    Args:
    - model: The LLM model to be used. Must provide ``generate``; if ``model.config.is_encoder_decoder``
      is truthy it is treated as an encoder-decoder model.
    - tokenizer: The tokenizer to be used
    - input_sequence (str): The input text sequence.
    - device: The device the encoded input is moved to (e.g. "cpu" or "cuda").
    - min_length (int): The minimum number of tokens to generate. For decoder-only models the
      input length is added on top before calling ``generate``, because there the length limits
      count the input tokens too; for encoder-decoder models it is passed through unchanged.
    - max_length (int): The maximum length of the output sequence, passed to ``generate`` unchanged.
    - top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering.
    - top_p (float): The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling.
    - temperature (float): The sampling temperature.
    - num_return_sequences (int): The total number of output sequences to generate.
    - batch_size (int): The maximum number of sequences generated per ``generate`` call.
    - print_output (bool): Whether to print the decoded output sequences.

    Returns:
    - list: The list of output sequences, encoded (one 1D tensor per sequence). The input sequence is not included in the output sequence!

    Raises:
    - ValueError: If ``batch_size`` is smaller than 1.
    """

    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    # Encode the input sequence
    input_ids = tokenizer.encode(input_sequence, return_tensors="pt").to(device)
    prompt_ids = input_ids.view(-1)  # 1D view, used to detect the echoed input later
    prompt_length = prompt_ids.size(0)

    # Decoder-only models count the input tokens towards the generated length, so the
    # minimum has to be offset by the input length. Encoder-decoder models only count
    # decoder tokens; adding the input length there would inflate the minimum (and can
    # push it past max_length).
    is_encoder_decoder = getattr(getattr(model, "config", None), "is_encoder_decoder", False)
    effective_min_length = min_length if is_encoder_decoder else min_length + prompt_length

    # Ceiling division without going through floats
    num_batches = -(-num_return_sequences // batch_size)

    # Generate output sequences using sampling
    # Turn off gradient calculation to speed up the process
    total_output_sequences = []
    with torch.no_grad():
        for i in range(num_batches):
            print('---> Generating batch: ', i)
            start = time.perf_counter()
            # The last batch may be smaller than batch_size
            sequences_to_generate = min(batch_size, num_return_sequences - len(total_output_sequences))
            output_sequences = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                min_length=effective_min_length,
                do_sample=True,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                num_return_sequences=sequences_to_generate,
            )
            end = time.perf_counter()
            print(f'Time taken for {i}-th batch: ', end - start)
            total_output_sequences.extend(output_sequences)

    # If an output sequence also includes the input sequence, remove it from the output sequence
    adjusted_output_sequences = [
        generated[prompt_length:] if generated[:prompt_length].equal(prompt_ids) else generated
        for generated in total_output_sequences
    ]

    # Decode and print the output sequences
    if print_output:
        print("Output:\n" + 100 * "-")
        for i, output_sequence in enumerate(adjusted_output_sequences):
            print("{}: {}".format(i, tokenizer.decode(output_sequence, skip_special_tokens=True)))

    return adjusted_output_sequences


def plot_histogram(input_numbers):
    """
    Draw a 100-bin histogram of the given values and display it.

    Args:
    - input_numbers (list): The values to bin (the axis labels describe them as log probabilities).
    """

    # Work on the current axes, which is what the pyplot-level helpers draw on as well
    axes = plt.gca()
    axes.hist(input_numbers, bins=100)
    axes.set_xlabel("Log Probability")
    axes.set_ylabel("Frequency")
    axes.set_title("Histogram of log probabilities")
    plt.show()



if __name__ == "__main__":
    # Demo: sample an answer to a user prompt, ask the model to rewrite that answer many
    # times, then compare how likely the original and the rewrites are given the user prompt.

    # Both helper functions require an explicit device; use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model and tokenizer (flan)
    model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large").to(device)
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")

    # Create original user input sequence and output sequence
    user_input_sequence = "Create a 5-day itenerary for a trip to New York City."
    rewrite_prompt = "Rewrite this sentence: "
    original_output_sequence = sample_sentences(
        model=model, tokenizer=tokenizer, input_sequence=user_input_sequence, device=device,
        max_length=100, top_k=50, top_p=0.95, num_return_sequences=1, print_output=False,
    )
    original_text = tokenizer.decode(original_output_sequence[0], skip_special_tokens=True)

    # Calculate probability of original output sequence
    print('-'*50)
    print(f"Original sequence: {original_text}")
    original_probability, original_log_probability = calculate_sequence_probability(
        model=model, tokenizer=tokenizer, input_sequence=user_input_sequence,
        output_sequence=original_text, encoded_sequence=False, device=device,
    )

    # Rephrase the output sequence using the "proposal distribution"
    proposal_distribution_sequence = rewrite_prompt + original_text
    output_sequences = sample_sentences(
        model=model, tokenizer=tokenizer, input_sequence=proposal_distribution_sequence, device=device,
        max_length=100, top_k=50, top_p=0.95, num_return_sequences=100, print_output=False,
    )

    # Calculate probability of each output sequence
    probabilities = []
    log_probabilities = []
    for i, output_sequence in enumerate(output_sequences):
        # Decode once; the text is both printed and re-scored against the user prompt
        rewritten_text = tokenizer.decode(output_sequence, skip_special_tokens=True)
        print('-'*50)
        print(f"Sequence {i}: {rewritten_text}")
        probability, log_probability = calculate_sequence_probability(
            model=model, tokenizer=tokenizer, input_sequence=user_input_sequence,
            output_sequence=rewritten_text, encoded_sequence=False, device=device,
            remove_start_token=False,
        )

        probabilities.append(probability)
        log_probabilities.append(log_probability)

    # Histograms of the scores can be drawn with plot_histogram(probabilities)
    # or plot_histogram(log_probabilities).

    print('-'*50)
    print('Original probability:', original_probability, 'Original log probability:', original_log_probability)
    print('Mean proposal probability:', np.mean(probabilities), 'Mean log probability:', np.mean(log_probabilities))